ENH: Add native save and read support for SSD #13718
Aniketsy wants to merge 8 commits into mne-tools:main
mne/decoding/ssd.py
Outdated
    def _create_cov_callable(self):
        """Recreate covariance callable after initialization or loading."""
        self.cov_callable = partial(
            _ssd_estimate,
            reg=self.reg,
            cov_method_params=self.cov_method_params,
            info=self.info,
            picks=self.picks,
            n_fft=self.n_fft,
            filt_params_signal=self.filt_params_signal,
            filt_params_noise=self.filt_params_noise,
            rank=self.rank,
            sort_by_spectral_ratio=self.sort_by_spectral_ratio,
        )
        self.mod_ged_callable = _ssd_mod
Personally I would prefer that this method return the partial object and _ssd_mod variable, and leave it to the call site to set the cov_callable and mod_ged_callable attributes. The reason being that, it is not clear for the reader what self._create_cov_callable() is doing, or that it is mutating the class instance. In fact, you could make this a function rather than a method, so that it is clear from the call site what needs to be passed in in order to get the output.
Moving this out of __init__ and having this as a method was my suggestion, since we would now also need to create cov_callable for __setstate__, but you're right that its behaviour is obscure.
+1 for returning the functions rather than mutating the instance.
thanks for the review, and correction
tsbinns left a comment
Really strong start, thanks @Aniketsy! Have some comments below for suggested next steps. Only skimmed the tests, but will have a proper look once the changes below are addressed.
Just to note, the idea would also be to add this functionality to the other decoding classes SPoC, CSP and XdawnTransformer, but that should be a very simple copy-paste job once the template is set in stone for SSD since they all inherit from the _GEDTransformer parent class.
mne/decoding/ssd.py
Outdated
    -    filt_params_signal,
    -    filt_params_noise,
    +    filt_params_signal=None,
    +    filt_params_noise=None,
These should stay as not having a default. Having a default None could lead users to assume the class will work without specifying them, but that is not the case. Would be better to leave as having no default, and just pass a placeholder value (e.g., None) when we need to init from a state dict.
mne/decoding/ssd.py
Outdated
    saved_version = state.get("mne_version")
    if saved_version is not None and saved_version != _mne_version:
        warn(
            f"The SSD object was saved with MNE-Python {saved_version} but is "
            f"being loaded with {_mne_version}. This may cause issues."
        )
This shouldn't be necessary. If we made changes to the SSD class in a future version, we would try to do so in a way that is backwards-compatible and only give a warning like this if we knew with more certainty that some things might not behave as expected.
mne/decoding/ssd.py
Outdated
    state.pop("mod_ged_callable", None)
    state["_ssd_state"] = True
    state["class_name"] = "SSD"
    state["mne_version"] = _mne_version
See other comment on versioning that we wouldn't normally add this info.
mne/decoding/ssd.py
Outdated
    state = self.__dict__.copy()
    state.pop("cov_callable", None)
    state.pop("mod_ged_callable", None)
    state["_ssd_state"] = True
This shouldn't be necessary to add. It should be sufficient for any checks (e.g., when initing from it) that the state dict has all the required attr keys.
mne/decoding/ssd.py
Outdated
    def _create_cov_callable(self):
        """Recreate covariance callable after initialization or loading."""
    -    def _create_cov_callable(self):
    -        """Recreate covariance callable after initialization or loading."""
    +    def _create_callables(self):
    +        """Create covariance callable on initialization or state loading."""
Since this is creating both cov_callable and mod_ged_callable, I think a better name for the method would be something like _create_callables. Also, a suggestion for a nitpicky change to the docstring.
Consider also this alternative suggestion of returning the functions rather than assigning them to the instance within the method: #13718 (comment)
mne/decoding/ssd.py
Outdated
    required_keys = (
        "info",
        "filt_params_signal",
        "filt_params_noise",
        "n_components",
        "filters_",
        "patterns_",
    )
    missing = [k for k in required_keys if k not in info]
    if missing:
        raise ValueError(
            "If 'info' is a dict, it must be a serialized SSD state "
            f"(missing keys: {missing}). "
            "Otherwise pass an mne.Info object."
        )
A few things:
- I think rather than clutter __init__, it would be cleaner to house this in __setstate__.
- I think that all of the expected state dict keys should be considered required, and checked for.
- For the error message, we don't need users to know that they can instantiate the SSD class from a state dict. Rather, if the info they pass is a dict and it doesn't match the expected state dict format, we should raise a TypeError and tell them they need to pass an Info object. mne.utils.check._validate_type is an easy way to check for this.
thanks for the pointers.
mne/decoding/ssd.py
Outdated
    with open(fname, "wb") as fid:
        pickle.dump(state, fid)
Rather than dumping to a pickle file, it would be good to maintain consistency with how this is handled, e.g., for mne.time_frequency.Spectrum objects (i.e., save in HDF5 format):
mne-python/mne/time_frequency/spectrum.py
Lines 939 to 957 in 03436eb
In check_fname, the filetype param could be "ssd" instead of "spectrum".
Ah whoops, was mid-review and didn't see your comments before @scott-huberty!
scott-huberty, tsbinns, larsoner thanks for the review and clarification, I’ve addressed most of the points mentioned during the review and am revisiting the changes to ensure I didn’t miss any points.
sure, that sounds good. We can proceed with extending this to the other classes once the
Ah, I only ran the tests locally. I’ll make sure to build the docs locally as well before pushing the changes in future.
@tsbinns I've added
mne/decoding/ssd.py
Outdated
    if "info" in state and not isinstance(state["info"], Info):
        state["info"] = Info(state["info"])
An "info" entry should always be in the state dict. And does the state dict ever contain an instantiated Info object?
I would think it's not needed to nest this in an if statement, just L483 alone should work, no?
> An "info" entry should always be in the state dict. And does the state dict ever contain an instantiated Info object?
> I would think it's not needed to nest this in an if statement, just L483 alone should work, no?
Yes, I've just looked into this and I agree; my previous assumption about removing this was wrong, so I will update it. I'll move the Info reconstruction to __setstate__ rather than read_ssd, following the same pattern as spectrum.__setstate__.
FAILED mne/decoding/tests/test_ssd.py::test_sklearn_compliance[SSD(filt_params_noise={'h_freq':40.0,'l_freq':0.0},filt_params_signal={'h_freq':30.0,'l_freq':0.0},info=100.0)-check_estimators_pickle] - TypeError: mne._fiff.meas_info.Info() argument after ** must be a mapping, not float
FAILED mne/decoding/tests/test_ssd.py::test_sklearn_compliance[SSD(filt_params_noise={'h_freq':40.0,'l_freq':0.0},filt_params_signal={'h_freq':30.0,'l_freq':0.0},info=100.0)-check_estimators_pickle(readonly_memmap=True)] - TypeError: mne._fiff.meas_info.Info() argument after ** must be a mapping, not float
The sklearn pickle test is failing because SSD accepts info as a float, but __setstate__ was unconditionally calling Info(**state["info"]), which crashes on a float. We need to add an isinstance(state["info"], dict) guard to fix the crash.
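A minimal sketch of the guard described above, assuming the state stores info as a plain dict whenever it came from an Info object. `Info` here is a stand-in class and `SSDSketch` a hypothetical toy, not the real MNE classes.

```python
class Info:
    """Stand-in for mne.Info, for illustration only."""

    def __init__(self, **fields):
        self.fields = fields


class SSDSketch:
    """Toy class showing only the info handling in the state methods."""

    def __init__(self, info):
        self.info = info

    def __getstate__(self):
        state = self.__dict__.copy()
        # Serialize Info as a plain dict; leave anything else untouched.
        if isinstance(state["info"], Info):
            state["info"] = dict(state["info"].fields)
        return state

    def __setstate__(self, state):
        state = dict(state)
        # Guard: only rebuild Info from a mapping, so a float (as passed by
        # the sklearn compliance checks) passes through unchanged.
        if isinstance(state["info"], dict):
            state["info"] = Info(**state["info"])
        self.__dict__.update(state)
```

With the guard in place, a non-mapping info value is restored as-is and any type error is left to whichever later step validates info.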
mne/decoding/ssd.py
Outdated
    See Also
    --------
    SSD.save
Should add the full path mne.decoding.SSD.save.
mne/decoding/tests/test_ssd.py
Outdated
    fname_rt = tmp_path / "test_ssd_rt.h5"
    ssd.save(fname_rt)
    ssd_rt = read_ssd(fname_rt)
    assert_array_almost_equal(ssd.filters_, ssd_rt.filters_)
    assert_array_almost_equal(ssd.transform(X), ssd_rt.transform(X))
What is the purpose of this extra check? Is this not already tested?
Sorry, yes, this is a redundant check; I'll remove it.
@tsbinns thanks! I've addressed the comment you mentioned. Should I extend this framework to the other classes now?
            "The state may be from an incompatible version of MNE."
        )
    if state["info"] is not None:
        state["info"] = Info(**state["info"])
@tsbinns I kept the if guard in CSP, SPoC, and XdawnTransformer because info can be None for these classes. Please let me know what you think.
Fixes #13328